//===--- InstOptUtils.cpp - PILOptimizer instruction utilities ------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "polarphp/pil/optimizer/utils/InstOptUtils.h"
#include "polarphp/ast/GenericSignature.h"
#include "polarphp/ast/SubstitutionMap.h"
#include "polarphp/ast/SemanticAttrs.h"
#include "polarphp/pil/lang/BasicBlockUtils.h"
#include "polarphp/pil/lang/DebugUtils.h"
#include "polarphp/pil/lang/InstructionUtils.h"
#include "polarphp/pil/lang/PILArgument.h"
#include "polarphp/pil/lang/PILBuilder.h"
#include "polarphp/pil/lang/PILModule.h"
#include "polarphp/pil/lang/PILUndef.h"
#include "polarphp/pil/lang/TypeLowering.h"
#include "polarphp/pil/optimizer/analysis/ARCAnalysis.h"
#include "polarphp/pil/optimizer/analysis/DominanceAnalysis.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include <deque>

using namespace polar;

/// Command-line flag "-enable-expand-all" (off by default).
/// NOTE(review): no use of this flag is visible in this chunk of the file;
/// presumably it gates an expansion utility defined further down — confirm.
static llvm::cl::opt<bool> EnableExpandAll("enable-expand-all",
                                           llvm::cl::init(false));

/// Creates an increment on \p Ptr before insertion point \p InsertPt that
/// creates a strong_retain if \p Ptr has reference semantics itself or a
/// retain_value if \p Ptr is a non-trivial value without reference-semantics.
NullablePtr<PILInstruction>
polar::createIncrementBefore(PILValue ptr, PILInstruction *insertPt) {
   // Build immediately before the requested insertion point, using an
   // auto-generated (compiler-synthesized) source location.
   PILBuilder b(insertPt);
   auto autoGenLoc = RegularLocation::getAutoGeneratedLocation();

   // Trivial values carry no ownership; nothing to emit.
   auto type = ptr->getType();
   if (type.isTrivial(b.getFunction()))
      return nullptr;

   // Non-trivial values that are not themselves reference counted are
   // incremented with a retain_value.
   if (!type.isReferenceCounted(b.getModule()))
      return b.createRetainValue(autoGenLoc, ptr, b.getDefaultAtomicity());

   // Reference-counted values: unowned storage needs an unowned_retain,
   // everything else a strong_retain.
   if (type.is<UnownedStorageType>())
      return b.createUnownedRetain(autoGenLoc, ptr, b.getDefaultAtomicity());
   return b.createStrongRetain(autoGenLoc, ptr, b.getDefaultAtomicity());
}

/// Creates a decrement on \p ptr before insertion point \p InsertPt that
/// creates a strong_release if \p ptr has reference semantics itself or
/// a release_value if \p ptr is a non-trivial value without
/// reference-semantics.
NullablePtr<PILInstruction>
polar::createDecrementBefore(PILValue ptr, PILInstruction *insertPt) {
   // Build immediately before the requested insertion point, using an
   // auto-generated (compiler-synthesized) source location.
   PILBuilder b(insertPt);
   auto autoGenLoc = RegularLocation::getAutoGeneratedLocation();

   // Trivial values carry no ownership; nothing to emit.
   auto type = ptr->getType();
   if (type.isTrivial(b.getFunction()))
      return nullptr;

   // Non-trivial values that are not themselves reference counted are
   // decremented with a release_value.
   if (!type.isReferenceCounted(b.getModule()))
      return b.createReleaseValue(autoGenLoc, ptr, b.getDefaultAtomicity());

   // Reference-counted values: unowned storage needs an unowned_release,
   // everything else a strong_release.
   if (type.is<UnownedStorageType>())
      return b.createUnownedRelease(autoGenLoc, ptr, b.getDefaultAtomicity());
   return b.createStrongRelease(autoGenLoc, ptr, b.getDefaultAtomicity());
}

/// Perform a fast local check to see if the instruction is dead.
///
/// This routine only examines the state of the instruction at hand.
bool polar::isInstructionTriviallyDead(PILInstruction *inst) {
   // At Onone, consider all uses, including the debug_info.
   // This way, debug_info is preserved at Onone.
   if (inst->hasUsesOfAnyResult()
       && inst->getFunction()->getEffectiveOptimizationMode()
          <= OptimizationMode::NoOptimization)
      return false;

   // A value with real (non-debug) uses is live; terminators are never
   // deleted here because they structure the CFG.
   if (!onlyHaveDebugUsesOfAllResults(inst) || isa<TermInst>(inst))
      return false;

   if (auto *bi = dyn_cast<BuiltinInst>(inst)) {
      // Although the onFastPath builtin has no side-effects we don't want to
      // remove it.
      if (bi->getBuiltinInfo().ID == BuiltinValueKind::OnFastPath)
         return false;
      return !bi->mayHaveSideEffects();
   }

   // condfail instructions that obviously can't fail are dead.
   if (auto *cfi = dyn_cast<CondFailInst>(inst))
      if (auto *ili = dyn_cast<IntegerLiteralInst>(cfi->getOperand()))
         if (!ili->getValue())
            return true;

   // mark_uninitialized is never dead.
   if (isa<MarkUninitializedInst>(inst))
      return false;

   // Debug instructions survive the non-debug-use check above; keep them.
   if (isa<DebugValueInst>(inst) || isa<DebugValueAddrInst>(inst))
      return false;

   // These invalidate enums so "write" memory, but that is not an essential
   // operation so we can remove these if they are trivially dead.
   if (isa<UncheckedTakeEnumDataAddrInst>(inst))
      return true;

   // Anything else without side effects (and, per the checks above, without
   // non-debug uses) is removable.
   if (!inst->mayHaveSideEffects())
      return true;

   return false;
}

/// Return true if this is a release instruction and the released value
/// is a part of a guaranteed parameter.
bool polar::isIntermediateRelease(PILInstruction *inst,
                                  EpilogueARCFunctionInfo *eafi) {
   // Check whether this is a release instruction.
   if (!isa<StrongReleaseInst>(inst) && !isa<ReleaseValueInst>(inst))
      return false;

   // OK. we have a release instruction.
   // Check whether this is a release on part of a guaranteed function argument.
   // Strip value projections so a release of e.g. a struct_extract of the
   // argument is still attributed to the argument itself.
   PILValue Op = stripValueProjections(inst->getOperand(0));
   auto *arg = dyn_cast<PILFunctionArgument>(Op);
   if (!arg)
      return false;

   // This is a release on a guaranteed parameter. Its not the final release.
   if (arg->hasConvention(PILArgumentConvention::Direct_Guaranteed))
      return true;

   // This is a release on an owned parameter and its not the epilogue release.
   // Its not the final release.
   auto rel = eafi->computeEpilogueARCInstructions(
      EpilogueARCContext::EpilogueARCKind::Release, arg);
   if (rel.size() && !rel.count(inst))
      return true;

   // Failed to prove anything.
   return false;
}

namespace {
// Callback invoked for each instruction about to be deleted by the
// recursive-deletion utilities below.
using CallbackTy = llvm::function_ref<void(PILInstruction *)>;
} // end anonymous namespace

void polar::recursivelyDeleteTriviallyDeadInstructions(
   ArrayRef<PILInstruction *> ia, bool force, CallbackTy callback) {
   // Delete these instruction and others that become dead after it's deleted.
   llvm::SmallPtrSet<PILInstruction *, 8> deadInsts;
   for (auto *inst : ia) {
      // If the instruction is not dead and force is false, do nothing.
      if (force || isInstructionTriviallyDead(inst))
         deadInsts.insert(inst);
   }
   // Worklist iteration: deleting one batch may expose operands that are now
   // dead; those are collected into nextInsts and processed on the next round.
   llvm::SmallPtrSet<PILInstruction *, 8> nextInsts;
   while (!deadInsts.empty()) {
      for (auto inst : deadInsts) {
         // Call the callback before we mutate the to be deleted instruction in any
         // way.
         callback(inst);

         // Check if any of the operands will become dead as well.
         MutableArrayRef<Operand> operands = inst->getAllOperands();
         for (Operand &operand : operands) {
            PILValue operandVal = operand.get();
            if (!operandVal)
               continue;

            // Remove the reference from the instruction being deleted to this
            // operand.
            operand.drop();

            // If the operand is an instruction that is only used by the instruction
            // being deleted, delete it.
            if (auto *operandValInst = operandVal->getDefiningInstruction())
               if (!deadInsts.count(operandValInst)
                   && isInstructionTriviallyDead(operandValInst))
                  nextInsts.insert(operandValInst);
         }

         // If we have a function ref inst, we need to especially drop its function
         // argument so that it gets a proper ref decrement.
         auto *fri = dyn_cast<FunctionRefInst>(inst);
         if (fri && fri->getInitiallyReferencedFunction())
            fri->dropReferencedFunction();

         auto *dfri = dyn_cast<DynamicFunctionRefInst>(inst);
         if (dfri && dfri->getInitiallyReferencedFunction())
            dfri->dropReferencedFunction();

         auto *pfri = dyn_cast<PreviousDynamicFunctionRefInst>(inst);
         if (pfri && pfri->getInitiallyReferencedFunction())
            pfri->dropReferencedFunction();
      }

      // Second phase: actually erase. Done after all operand references in
      // this batch were dropped so no erased instruction is still referenced.
      for (auto inst : deadInsts) {
         // This will remove this instruction and all its uses.
         eraseFromParentWithDebugInsts(inst, callback);
      }

      // Promote the newly-discovered dead instructions to the next round.
      nextInsts.swap(deadInsts);
      nextInsts.clear();
   }
}

/// If the given instruction is dead, delete it along with its dead
/// operands.
///
/// \param inst The instruction to be deleted.
/// \param force If force is set, don't check if the top level instruction is
///        considered dead - delete it regardless.
void polar::recursivelyDeleteTriviallyDeadInstructions(PILInstruction *inst,
                                                       bool force,
                                                       CallbackTy callback) {
   // Forward to the ArrayRef-based overload with a single-element list.
   recursivelyDeleteTriviallyDeadInstructions(ArrayRef<PILInstruction *>(inst),
                                              force, callback);
}

void polar::eraseUsesOfInstruction(PILInstruction *inst, CallbackTy callback) {
   for (auto result : inst->getResults()) {
      // Re-fetch use_begin() each iteration: erasing a user mutates the use
      // list, so a cached iterator would dangle.
      while (!result->use_empty()) {
         auto ui = result->use_begin();
         auto *user = ui->getUser();
         assert(user && "User should never be NULL!");

         // If the instruction itself has any uses, recursively zap them so that
         // nothing uses this instruction.
         eraseUsesOfInstruction(user, callback);

         // Walk through the operand list and delete any random instructions that
         // will become trivially dead when this instruction is removed.

         for (auto &operand : user->getAllOperands()) {
            if (auto *operandI = operand.get()->getDefiningInstruction()) {
               // Don't recursively delete the instruction we're working on.
               // FIXME: what if we're being recursively invoked?
               if (operandI != inst) {
                  operand.drop();
                  recursivelyDeleteTriviallyDeadInstructions(operandI, false,
                                                             callback);
               }
            }
         }
         // Notify the caller, then erase the (now use-free) user itself.
         callback(user);
         user->eraseFromParent();
      }
   }
}

/// Recursively collect into \p insts the transitive closure of instructions
/// that use \p v (users of v, users of those users' results, and so on).
///
/// \param v     The value whose (transitive) users are collected.
/// \param insts Accumulator set; also serves as the visited set, so
///              already-recorded users are not re-traversed.
void polar::collectUsesOfValue(PILValue v,
                               llvm::SmallPtrSetImpl<PILInstruction *> &insts) {
   // Pre-increment the use iterator: post-increment copies the iterator for
   // no benefit.
   for (auto ui = v->use_begin(), e = v->use_end(); ui != e; ++ui) {
      auto *user = ui->getUser();
      // Instruction has been processed.
      if (!insts.insert(user).second)
         continue;

      // Collect the users of this instruction.
      for (auto result : user->getResults())
         collectUsesOfValue(result, insts);
   }
}

void polar::eraseUsesOfValue(PILValue v) {
   // Gather the full transitive set of users first.
   llvm::SmallPtrSet<PILInstruction *, 4> usersToErase;
   collectUsesOfValue(v, usersToErase);

   // Erase the uses, we can have instructions that become dead because
   // of the removal of these instructions, leave to DCE to cleanup.
   // Its not safe to do recursively delete here as some of the PILInstruction
   // maybe tracked by this set.
   for (auto *user : usersToErase) {
      user->replaceAllUsesOfAllResultsWithUndef();
      user->eraseFromParent();
   }
}

// Devirtualization of functions with covariant return types produces
// a result that is not an apply, but takes an apply as an
// argument. Attempt to dig the apply out from this result.
FullApplySite polar::findApplyFromDevirtualizedResult(PILValue v) {
   // Iteratively peel off the casts/enum wrappers that devirtualization of
   // covariant-return functions inserts, until we hit the apply (or give up).
   while (true) {
      if (auto apply = FullApplySite::isa(v))
         return apply;

      bool isWrapper = isa<UpcastInst>(v) || isa<EnumInst>(v)
                       || isa<UncheckedRefCastInst>(v);
      if (!isWrapper)
         return FullApplySite();

      v = cast<SingleValueInstruction>(v)->getOperand(0);
   }
}

bool polar::mayBindDynamicSelf(PILFunction *F) {
   // Without a self-metadata parameter there is nothing to bind.
   if (!F->hasSelfMetadataParam())
      return false;

   // Report true if any user of the self-metadata argument depends on it as a
   // type-dependent operand.
   PILValue mdArg = F->getSelfMetadataArgument();
   for (Operand *mdUse : mdArg->getUses()) {
      for (Operand &typeDepOp : mdUse->getUser()->getTypeDependentOperands())
         if (typeDepOp.get() == mdArg)
            return true;
   }
   return false;
}

/// Walk up through address projections (index_addr, index_raw_pointer,
/// struct_element_addr, tuple_element_addr) and return the underlying value.
static PILValue skipAddrProjections(PILValue v) {
   while (true) {
      auto kind = v->getKind();
      bool isProjection = kind == ValueKind::IndexAddrInst
                          || kind == ValueKind::IndexRawPointerInst
                          || kind == ValueKind::StructElementAddrInst
                          || kind == ValueKind::TupleElementAddrInst;
      if (!isProjection)
         return v;
      v = cast<SingleValueInstruction>(v)->getOperand(0);
   }
}

/// Check whether the \p addr is an address of a tail-allocated array element.
bool polar::isAddressOfArrayElement(PILValue addr) {
   // Look through address projections, and through a mark_dependence to the
   // value it guards.
   addr = stripAddressProjections(addr);
   if (auto *md = dyn_cast<MarkDependenceInst>(addr))
      addr = stripAddressProjections(md->getValue());

   // High-level PIL: check for an get_element_address array semantics call.
   if (auto *ptrToAddr = dyn_cast<PointerToAddressInst>(addr))
      if (auto *sei = dyn_cast<StructExtractInst>(ptrToAddr->getOperand())) {
         ArraySemanticsCall call(sei->getOperand());
         if (call && call.getKind() == ArrayCallKind::kGetElementAddress)
            return true;
      }

   // Check for an tail-address (of an array buffer object).
   if (isa<RefTailAddrInst>(skipAddrProjections(addr)))
      return true;

   return false;
}

/// Find a new position for an ApplyInst's FuncRef so that it dominates its
/// use. Not that FunctionRefInsts may be shared by multiple ApplyInsts.
void polar::placeFuncRef(ApplyInst *ai, DominanceInfo *domInfo) {
   auto *funcRef = cast<FunctionRefInst>(ai->getCallee());
   auto *domBB = domInfo->findNearestCommonDominator(ai->getParent(),
                                                     funcRef->getParent());

   // Prefer to place the FuncRef immediately before the call. Since we're
   // moving FuncRef up, this must be the only call to it in the block.
   bool placeAtApply = domBB == ai->getParent() && domBB != funcRef->getParent();
   if (placeAtApply) {
      funcRef->moveBefore(ai);
      return;
   }

   // Otherwise, conservatively stick it at the beginning of the block.
   funcRef->moveBefore(&*domBB->begin());
}

/// Add an argument, \p val, to the branch-edge that is pointing into
/// block \p Dest. Return a new instruction and do not erase the old
/// instruction.
TermInst *polar::addArgumentToBranch(PILValue val, PILBasicBlock *dest,
                                     TermInst *branch) {
   PILBuilderWithScope builder(branch);

   if (auto *cbi = dyn_cast<CondBranchInst>(branch)) {
      // Copy both edges' existing arguments, then append the new value to
      // whichever edge targets \p dest.
      SmallVector<PILValue, 8> trueArgs(cbi->getTrueArgs().begin(),
                                        cbi->getTrueArgs().end());
      SmallVector<PILValue, 8> falseArgs(cbi->getFalseArgs().begin(),
                                         cbi->getFalseArgs().end());

      auto &edgeArgs = (dest == cbi->getTrueBB()) ? trueArgs : falseArgs;
      edgeArgs.push_back(val);
      assert(edgeArgs.size() == dest->getNumArguments());

      return builder.createCondBranch(
         cbi->getLoc(), cbi->getCondition(), cbi->getTrueBB(), trueArgs,
         cbi->getFalseBB(), falseArgs, cbi->getTrueBBCount(),
         cbi->getFalseBBCount());
   }

   if (auto *bi = dyn_cast<BranchInst>(branch)) {
      // Unconditional branch: just extend its argument list.
      SmallVector<PILValue, 8> args(bi->getArgs().begin(), bi->getArgs().end());
      args.push_back(val);
      assert(args.size() == dest->getNumArguments());
      return builder.createBranch(bi->getLoc(), bi->getDestBB(), args);
   }

   llvm_unreachable("unsupported terminator");
}

PILLinkage polar::getSpecializedLinkage(PILFunction *f, PILLinkage linkage) {
   // Specializations of private symbols should remain so, unless
   // they were serialized, which can only happen when specializing
   // definitions from a standard library built with -sil-serialize-all.
   bool keepPrivate = hasPrivateVisibility(linkage) && !f->isSerialized();
   return keepPrivate ? PILLinkage::Private : PILLinkage::Shared;
}

/// Cast a value into the expected, ABI compatible type if necessary.
/// This may happen e.g. when:
/// - a type of the return value is a subclass of the expected return type.
/// - actual return type and expected return type differ in optionality.
/// - both types are tuple-types and some of the elements need to be casted.
///
/// If CheckOnly flag is set, then this function only checks if the
/// required casting is possible. If it is not possible, then None
/// is returned.
///
/// If CheckOnly is not set, then a casting code is generated and the final
/// casted value is returned.
///
/// NOTE: We intentionally combine the checking of the cast's handling
/// possibility and the transformation performing the cast in the same function,
/// to avoid any divergence between the check and the implementation in the
/// future.
///
/// NOTE: The implementation of this function is very closely related to the
/// rules checked by PILVerifier::requireABICompatibleFunctionTypes.
PILValue polar::castValueToABICompatibleType(PILBuilder *builder,
                                             PILLocation loc, PILValue value,
                                             PILType srcTy, PILType destTy) {

   // No cast is required if types are the same.
   if (srcTy == destTy)
      return value;

   assert(srcTy.isAddress() == destTy.isAddress()
          && "Addresses aren't compatible with values");

   // (Given the assert above, checking both here is belt-and-braces.)
   if (srcTy.isAddress() && destTy.isAddress()) {
      // Cast between two addresses and that's it.
      return builder->createUncheckedAddrCast(loc, value, destTy);
   }

   // If both types are classes and dest is the superclass of src,
   // simply perform an upcast.
   if (destTy.isExactSuperclassOf(srcTy)) {
      return builder->createUpcast(loc, value, destTy);
   }

   // Unrelated heap references: bit-identical layout, unchecked ref cast.
   if (srcTy.isHeapObjectReferenceType() && destTy.isHeapObjectReferenceType()) {
      return builder->createUncheckedRefCast(loc, value, destTy);
   }

   // Metatype-to-metatype: only when the representations match.
   if (auto mt1 = srcTy.getAs<AnyMetatypeType>()) {
      if (auto mt2 = destTy.getAs<AnyMetatypeType>()) {
         if (mt1->getRepresentation() == mt2->getRepresentation()) {
            // If builder.Type needs to be casted to A.Type and
            // A is a superclass of builder, then it can be done by means
            // of a simple upcast.
            if (mt2.getInstanceType()->isExactSuperclassOf(mt1.getInstanceType())) {
               return builder->createUpcast(loc, value, destTy);
            }

            // Cast between two metatypes and that's it.
            return builder->createUncheckedBitCast(loc, value, destTy);
         }
      }
   }

   // Check if src and dest types are optional.
   auto optionalSrcTy = srcTy.getOptionalObjectType();
   auto optionalDestTy = destTy.getOptionalObjectType();

   // Both types are optional.
   if (optionalDestTy && optionalSrcTy) {
      // If both wrapped types are classes and dest is the superclass of src,
      // simply perform an upcast.
      if (optionalDestTy.isExactSuperclassOf(optionalSrcTy)) {
         // Insert upcast.
         return builder->createUpcast(loc, value, destTy);
      }

      // Otherwise emit a switch_enum diamond: unwrap, cast the payload
      // recursively, rewrap, and merge in a continuation block.
      // Unwrap the original optional value.
      auto *someDecl = builder->getAstContext().getOptionalSomeDecl();
      auto *noneBB = builder->getFunction().createBasicBlock();
      auto *someBB = builder->getFunction().createBasicBlock();
      auto *curBB = builder->getInsertionPoint()->getParent();

      // Split the current block at the insertion point; the continuation
      // block receives the casted result as a phi argument.
      auto *contBB = curBB->split(builder->getInsertionPoint());
      contBB->createPhiArgument(destTy, ValueOwnershipKind::Owned);

      SmallVector<std::pair<EnumElementDecl *, PILBasicBlock *>, 1> caseBBs;
      caseBBs.push_back(std::make_pair(someDecl, someBB));
      builder->setInsertionPoint(curBB);
      builder->createSwitchEnum(loc, value, noneBB, caseBBs);

      // Handle the Some case.
      builder->setInsertionPoint(someBB);
      PILValue unwrappedValue =
         builder->createUncheckedEnumData(loc, value, someDecl);
      // Cast the unwrapped value.
      auto castedUnwrappedValue = castValueToABICompatibleType(
         builder, loc, unwrappedValue, optionalSrcTy, optionalDestTy);
      // Wrap into optional.
      auto castedValue =
         builder->createOptionalSome(loc, castedUnwrappedValue, destTy);
      builder->createBranch(loc, contBB, {castedValue});

      // Handle the None case.
      builder->setInsertionPoint(noneBB);
      castedValue = builder->createOptionalNone(loc, destTy);
      builder->createBranch(loc, contBB, {castedValue});
      builder->setInsertionPoint(contBB->begin());

      return contBB->getArgument(0);
   }

   // Src is not optional, but dest is optional.
   if (!optionalSrcTy && optionalDestTy) {
      auto optionalSrcCanTy =
         OptionalType::get(srcTy.getAstType())->getCanonicalType();
      auto loweredOptionalSrcType =
         PILType::getPrimitiveObjectType(optionalSrcCanTy);

      // Wrap the source value into an optional first.
      PILValue wrappedValue =
         builder->createOptionalSome(loc, value, loweredOptionalSrcType);
      // Cast the wrapped value (recursing into the optional/optional case).
      return castValueToABICompatibleType(builder, loc, wrappedValue,
                                          wrappedValue->getType(), destTy);
   }

   // Handle tuple types.
   // Extract elements, cast each of them, create a new tuple.
   if (auto srcTupleTy = srcTy.getAs<TupleType>()) {
      SmallVector<PILValue, 8> expectedTuple;
      for (unsigned i = 0, e = srcTupleTy->getNumElements(); i < e; i++) {
         PILValue element = builder->createTupleExtract(loc, value, i);
         // Cast the value if necessary.
         element = castValueToABICompatibleType(builder, loc, element,
                                                srcTy.getTupleElementType(i),
                                                destTy.getTupleElementType(i));
         expectedTuple.push_back(element);
      }

      return builder->createTuple(loc, destTy, expectedTuple);
   }

   // Function types are interchangeable if they're also ABI-compatible.
   if (srcTy.is<PILFunctionType>()) {
      if (destTy.is<PILFunctionType>()) {
         assert(srcTy.getAs<PILFunctionType>()->isNoEscape()
                == destTy.getAs<PILFunctionType>()->isNoEscape()
                || srcTy.getAs<PILFunctionType>()->getRepresentation()
                   != PILFunctionType::Representation::Thick
                   && "Swift thick functions that differ in escapeness are "
                      "not ABI "
                      "compatible");
         // Insert convert_function.
         return builder->createConvertFunction(loc, value, destTy,
            /*WithoutActuallyEscaping=*/false);
      }
   }

   // No handled combination matched: this is a compiler bug, not user error.
   llvm::errs() << "Source type: " << srcTy << "\n";
   llvm::errs() << "Destination type: " << destTy << "\n";
   llvm_unreachable("Unknown combination of types for casting");
}

ProjectBoxInst *polar::getOrCreateProjectBox(AllocBoxInst *abi,
                                             unsigned index) {
   // Look at the instruction immediately following the alloc_box.
   auto it = PILBasicBlock::iterator(abi);
   ++it;
   assert(it != abi->getParent()->end()
          && "alloc_box cannot be the last instruction of a block");
   PILInstruction *next = &*it;

   // Reuse an existing projection of the same box and field, if present.
   auto *existing = dyn_cast<ProjectBoxInst>(next);
   if (existing && existing->getOperand() == abi
       && existing->getFieldIndex() == index)
      return existing;

   // Otherwise create a fresh project_box right after the alloc_box.
   PILBuilder builder(next);
   return builder.createProjectBox(abi->getLoc(), abi, index);
}

// Peek through trivial Enum initialization, typically for pointless
// Optionals.
//
// Given an UncheckedTakeEnumDataAddrInst, check that there are no
// other uses of the Enum value and return the address used to initialized the
// enum's payload:
//
//   %stack_adr = alloc_stack
//   %data_adr  = init_enum_data_addr %stk_adr
//   %enum_adr  = inject_enum_addr %stack_adr
//   %copy_src  = unchecked_take_enum_data_addr %enum_adr
//   dealloc_stack %stack_adr
//   (No other uses of %stack_adr.)
InitEnumDataAddrInst *
polar::findInitAddressForTrivialEnum(UncheckedTakeEnumDataAddrInst *utedai) {
   // Only handle enums living in a stack slot.
   auto *asi = dyn_cast<AllocStackInst>(utedai->getOperand());
   if (!asi)
      return nullptr;

   // Scan all users of the stack slot looking for exactly one "interesting"
   // user besides the take itself.
   PILInstruction *singleUser = nullptr;
   for (auto use : asi->getUses()) {
      auto *user = use->getUser();
      if (user == utedai)
         continue;

      // As long as there's only one UncheckedTakeEnumDataAddrInst and one
      // InitEnumDataAddrInst, we don't care how many InjectEnumAddr and
      // DeallocStack users there are.
      if (isa<InjectEnumAddrInst>(user) || isa<DeallocStackInst>(user))
         continue;

      // A second interesting user disqualifies the pattern.
      if (singleUser)
         return nullptr;

      singleUser = user;
   }
   if (!singleUser)
      return nullptr;

   // Assume, without checking, that the returned InitEnumDataAddr dominates the
   // given UncheckedTakeEnumDataAddrInst, because that's how PIL is defined. I
   // don't know where this is actually verified.
   return dyn_cast<InitEnumDataAddrInst>(singleUser);
}

//===----------------------------------------------------------------------===//
//                       String Concatenation Optimizer
//===----------------------------------------------------------------------===//

namespace {
/// This is a helper class that performs optimization of string literals
/// concatenation.
///
/// Fields are populated by extractStringConcatOperands() and consumed by the
/// other helpers; optimize() drives the whole sequence.
class StringConcatenationOptimizer {
   /// Apply instruction being optimized.
   ApplyInst *ai;
   /// Builder to be used for creation of new instructions.
   PILBuilder &builder;
   /// Left string literal operand of a string concatenation.
   StringLiteralInst *sliLeft = nullptr;
   /// Right string literal operand of a string concatenation.
   StringLiteralInst *sliRight = nullptr;
   /// Function used to construct the left string literal.
   FunctionRefInst *friLeft = nullptr;
   /// Function used to construct the right string literal.
   FunctionRefInst *friRight = nullptr;
   /// Apply instructions used to construct left string literal.
   ApplyInst *aiLeft = nullptr;
   /// Apply instructions used to construct right string literal.
   ApplyInst *aiRight = nullptr;
   /// String literal conversion function to be used.
   FunctionRefInst *friConvertFromBuiltin = nullptr;
   /// Result type of a function producing the concatenated string literal.
   PILValue funcResultType;

   /// Internal helper methods
   bool extractStringConcatOperands();
   void adjustEncodings();
   APInt getConcatenatedLength();
   bool isAscii() const;

public:
   StringConcatenationOptimizer(ApplyInst *ai, PILBuilder &builder)
      : ai(ai), builder(builder) {}

   /// Tries to optimize a given apply instruction if it is a
   /// concatenation of string literals.
   ///
   /// Returns a new instruction if optimization was possible.
   SingleValueInstruction *optimize();
};

} // end anonymous namespace

/// Checks operands of a string concatenation operation to see if
/// optimization is applicable.
///
/// Returns false if optimization is not possible.
/// Returns true and initializes internal fields if optimization is possible.
/// Checks operands of a string concatenation operation to see if
/// optimization is applicable.
///
/// Returns false if optimization is not possible.
/// Returns true and initializes internal fields if optimization is possible.
bool StringConcatenationOptimizer::extractStringConcatOperands() {
   auto *Fn = ai->getReferencedFunctionOrNull();
   if (!Fn)
      return false;

   // Must be a "string.concat" semantics call: (result-type, lhs, rhs).
   if (ai->getNumArguments() != 3 || !Fn->hasSemanticsAttr(semantics::STRING_CONCAT))
      return false;

   // Left and right operands of a string concatenation operation.
   aiLeft = dyn_cast<ApplyInst>(ai->getOperand(1));
   aiRight = dyn_cast<ApplyInst>(ai->getOperand(2));

   if (!aiLeft || !aiRight)
      return false;

   friLeft = dyn_cast<FunctionRefInst>(aiLeft->getCallee());
   friRight = dyn_cast<FunctionRefInst>(aiRight->getCallee());

   if (!friLeft || !friRight)
      return false;

   auto *friLeftFun = friLeft->getReferencedFunctionOrNull();
   auto *friRightFun = friRight->getReferencedFunctionOrNull();

   // getReferencedFunctionOrNull() may return null; bail out instead of
   // dereferencing a null function below.
   if (!friLeftFun || !friRightFun)
      return false;

   // Constructors with release-level (or stronger) effects cannot be folded.
   if (friLeftFun->getEffectsKind() >= EffectsKind::ReleaseNone
       || friRightFun->getEffectsKind() >= EffectsKind::ReleaseNone)
      return false;

   if (!friLeftFun->hasSemanticsAttrs() || !friRightFun->hasSemanticsAttrs())
      return false;

   auto aiLeftOperandsNum = aiLeft->getNumOperands();
   auto aiRightOperandsNum = aiRight->getNumOperands();

   // makeUTF8 should have following parameters:
   // (start: RawPointer, utf8CodeUnitCount: Word, isASCII: Int1)
   if (!((friLeftFun->hasSemanticsAttr(semantics::STRING_MAKE_UTF8)
          && aiLeftOperandsNum == 5)
         || (friRightFun->hasSemanticsAttr(semantics::STRING_MAKE_UTF8)
             && aiRightOperandsNum == 5)))
      return false;

   sliLeft = dyn_cast<StringLiteralInst>(aiLeft->getOperand(1));
   sliRight = dyn_cast<StringLiteralInst>(aiRight->getOperand(1));

   if (!sliLeft || !sliRight)
      return false;

   // Only UTF-8 and UTF-16 encoded string literals are supported by this
   // optimization.
   if (sliLeft->getEncoding() != StringLiteralInst::Encoding::UTF8
       && sliLeft->getEncoding() != StringLiteralInst::Encoding::UTF16)
      return false;

   if (sliRight->getEncoding() != StringLiteralInst::Encoding::UTF8
       && sliRight->getEncoding() != StringLiteralInst::Encoding::UTF16)
      return false;

   return true;
}

/// Ensures that both string literals to be concatenated use the same
/// UTF encoding. Converts UTF-8 into UTF-16 if required.
void StringConcatenationOptimizer::adjustEncodings() {
   // Same encoding on both sides: just record the conversion function and
   // result type (operand index differs between UTF8 and UTF16 makers).
   if (sliLeft->getEncoding() == sliRight->getEncoding()) {
      friConvertFromBuiltin = friLeft;
      if (sliLeft->getEncoding() == StringLiteralInst::Encoding::UTF8) {
         funcResultType = aiLeft->getOperand(4);
      } else {
         funcResultType = aiLeft->getOperand(3);
      }
      return;
   }

   builder.setCurrentDebugScope(ai->getDebugScope());

   // If one of the string literals is UTF8 and another one is UTF16,
   // convert the UTF8-encoded string literal into UTF16-encoding first.
   if (sliLeft->getEncoding() == StringLiteralInst::Encoding::UTF8
       && sliRight->getEncoding() == StringLiteralInst::Encoding::UTF16) {
      funcResultType = aiRight->getOperand(3);
      friConvertFromBuiltin = friRight;
      // Convert UTF8 representation into UTF16.
      sliLeft = builder.createStringLiteral(ai->getLoc(), sliLeft->getValue(),
                                            StringLiteralInst::Encoding::UTF16);
   }

   // Mirror case: right is UTF8, left is UTF16. (Mutually exclusive with the
   // branch above, since that branch rewrites sliLeft to UTF16.)
   if (sliRight->getEncoding() == StringLiteralInst::Encoding::UTF8
       && sliLeft->getEncoding() == StringLiteralInst::Encoding::UTF16) {
      funcResultType = aiLeft->getOperand(3);
      friConvertFromBuiltin = friLeft;
      // Convert UTF8 representation into UTF16.
      sliRight = builder.createStringLiteral(ai->getLoc(), sliRight->getValue(),
                                             StringLiteralInst::Encoding::UTF16);
   }

   // It should be impossible to have two operands with different
   // encodings at this point.
   assert(
      sliLeft->getEncoding() == sliRight->getEncoding()
      && "Both operands of string concatenation should have the same encoding");
}

/// Computes the length of a concatenated string literal.
///
/// Cross-checks (in asserts-enabled builds) that the code-unit count of each
/// literal matches the length operand reported to its string.make call, then
/// returns the sum of the two reported lengths (in code units of the already
/// unified encoding — adjustEncodings() must have run first).
APInt StringConcatenationOptimizer::getConcatenatedLength() {
   // Real length of string literals computed based on its contents.
   // Length is in code units.
   auto sliLenLeft = sliLeft->getCodeUnitCount();
   (void)sliLenLeft; // only used in asserts; silence release-build warnings
   auto sliLenRight = sliRight->getCodeUnitCount();
   (void)sliLenRight;

   // Length of string literals as reported by string.make functions.
   // Operand 2 is expected to be an integer literal (presumably established
   // by extractStringConcatOperands — confirm). Use cast<> rather than
   // dyn_cast<>: the pointers are dereferenced unconditionally below, and a
   // null dyn_cast<> result would crash without diagnosis in builds where
   // asserts are disabled; cast<> asserts the invariant in debug builds.
   auto *lenLeft = cast<IntegerLiteralInst>(aiLeft->getOperand(2));
   auto *lenRight = cast<IntegerLiteralInst>(aiRight->getOperand(2));

   // Real and reported length should be the same.
   assert(sliLenLeft == lenLeft->getValue()
          && "Size of string literal in @_semantics(string.make) is wrong");

   assert(sliLenRight == lenRight->getValue()
          && "Size of string literal in @_semantics(string.make) is wrong");

   // Compute length of the concatenated literal.
   return lenLeft->getValue() + lenRight->getValue();
}

/// Computes the isAscii flag of a concatenated UTF8-encoded string literal.
bool StringConcatenationOptimizer::isAscii() const {
   // Add the isASCII argument in case of UTF8.
   // IsASCII is true only if IsASCII of both literals is true.
   auto *asciiLeft = dyn_cast<IntegerLiteralInst>(aiLeft->getOperand(3));
   auto *asciiRight = dyn_cast<IntegerLiteralInst>(aiRight->getOperand(3));
   auto isAsciiLeft = asciiLeft->getValue() == 1;
   auto isAsciiRight = asciiRight->getValue() == 1;
   return isAsciiLeft && isAsciiRight;
}

/// Drives the optimization: if the apply's operands are two known string
/// literals, build a single merged literal plus the argument list for one
/// string.make apply and return the replacement instruction.
/// Returns nullptr if the pattern does not match.
SingleValueInstruction *StringConcatenationOptimizer::optimize() {
   // Bail out if string literals concatenation optimization is
   // not possible.
   if (!extractStringConcatOperands())
      return nullptr;

   // Perform string literal encodings adjustments if needed.
   adjustEncodings();

   // Arguments of the new StringLiteralInst to be created.
   SmallVector<PILValue, 4> arguments;

   // Encoding to be used for the concatenated string literal.
   // Both sides are guaranteed equal after adjustEncodings().
   auto encoding = sliLeft->getEncoding();

   // Create a concatenated string literal.
   builder.setCurrentDebugScope(ai->getDebugScope());
   auto lv = sliLeft->getValue();
   auto rv = sliRight->getValue();
   auto *newSLI =
      builder.createStringLiteral(ai->getLoc(), lv + Twine(rv), encoding);
   arguments.push_back(newSLI);

   // Length of the concatenated literal according to its encoding.
   auto *len = builder.createIntegerLiteral(
      ai->getLoc(), aiLeft->getOperand(2)->getType(), getConcatenatedLength());
   arguments.push_back(len);

   // isAscii flag for UTF8-encoded string literals.
   // (Only the UTF8 string.make entry point takes this extra argument.)
   if (encoding == StringLiteralInst::Encoding::UTF8) {
      bool ascii = isAscii();
      auto ilType = aiLeft->getOperand(3)->getType();
      auto *asciiLiteral =
         builder.createIntegerLiteral(ai->getLoc(), ilType, intmax_t(ascii));
      arguments.push_back(asciiLiteral);
   }

   // Type.
   arguments.push_back(funcResultType);

   // Re-apply the string.make conversion function selected by
   // adjustEncodings() on the newly built argument list.
   return builder.createApply(ai->getLoc(), friConvertFromBuiltin,
                              SubstitutionMap(), arguments);
}

/// Top level entry point: attempt to fold a concatenation of two string
/// literals rooted at \p ai into a single literal. Returns the replacement
/// instruction on success, or nullptr if the optimization does not apply.
SingleValueInstruction *polar::tryToConcatenateStrings(ApplyInst *ai,
                                                       PILBuilder &builder) {
   StringConcatenationOptimizer optimizer(ai, builder);
   return optimizer.optimize();
}

//===----------------------------------------------------------------------===//
//                              Closure Deletion
//===----------------------------------------------------------------------===//

/// NOTE: Instructions with transitive ownership kind are assumed to not keep
/// the underlying closure alive as well. This is meant for instructions only
/// with non-transitive users.
static bool useDoesNotKeepClosureAlive(const PILInstruction *inst) {
   switch (inst->getKind()) {
      case PILInstructionKind::StrongRetainInst:
      case PILInstructionKind::StrongReleaseInst:
      case PILInstructionKind::DestroyValueInst:
      case PILInstructionKind::RetainValueInst:
      case PILInstructionKind::ReleaseValueInst:
      case PILInstructionKind::DebugValueInst:
      case PILInstructionKind::EndBorrowInst:
         return true;
      default:
         return false;
   }
}

static bool useHasTransitiveOwnership(const PILInstruction *inst) {
   // convert_escape_to_noescape only converts to a @noescape function type
   // without changing ownership; copy_value and begin_borrow are inert for
   // our purposes. All three forward the value, so we must look through them
   // to their users.
   return isa<ConvertEscapeToNoEscapeInst>(inst) || isa<CopyValueInst>(inst)
          || isa<BeginBorrowInst>(inst);
}

/// Move the contents of address \p arg into a new alloc_stack whose live
/// range spans the whole function, and return the new stack location.
///
/// The alloc_stack is created in the entry block and a matching
/// dealloc_stack is emitted before the terminator of every block in
/// \p exitingBlocks. The value is moved in with a copy_addr [take] [init]
/// emitted at the builder's current insertion point. All created
/// instructions are reported through \p callbacks.
static PILValue createLifetimeExtendedAllocStack(
   PILBuilder &builder, PILLocation loc, PILValue arg,
   ArrayRef<PILBasicBlock *> exitingBlocks, InstModCallbacks callbacks) {
   AllocStackInst *asi = nullptr;
   {
      // Save our insert point and create a new alloc_stack in the initial BB and
      // dealloc_stack in all exit blocks.
      auto *oldInsertPt = &*builder.getInsertionPoint();
      builder.setInsertionPoint(builder.getFunction().begin()->begin());
      asi = builder.createAllocStack(loc, arg->getType());
      callbacks.createdNewInst(asi);

      for (auto *BB : exitingBlocks) {
         builder.setInsertionPoint(BB->getTerminator());
         callbacks.createdNewInst(builder.createDeallocStack(loc, asi));
      }
      // Restore the caller's insertion point before emitting the copy below.
      builder.setInsertionPoint(oldInsertPt);
   }
   assert(asi != nullptr);

   // Then perform a copy_addr [take] [init] right after the partial_apply from
   // the original address argument to the new alloc_stack that we have
   // created.
   callbacks.createdNewInst(
      builder.createCopyAddr(loc, arg, asi, IsTake, IsInitialization));

   // Return the new alloc_stack inst that has the appropriate live range to
   // destroy said values.
   return asi;
}

/// Returns true if a captured partial_apply argument requires a destroy.
///
/// @inout (indirect mutating) captures are not owned by the partial_apply —
/// something implicit in the current partial_apply design that will be
/// revisited when partial_apply is redesigned — and trivial values need no
/// cleanup either. Every other capture must be destroyed.
static bool shouldDestroyPartialApplyCapturedArg(PILValue arg,
                                                 PILParameterInfo paramInfo,
                                                 const PILFunction &F) {
   if (paramInfo.isIndirectMutating() || arg->getType().isTrivial(F))
      return false;

   // We handle all other cases.
   return true;
}

// *HEY YOU, YES YOU, PLEASE READ*. Even though a textual partial apply is
// printed with the convention of the closed over function upon it, all
// non-inout arguments to a partial_apply are passed at +1. This includes
// arguments that will eventually be passed as guaranteed or in_guaranteed to
// the closed over function. This is because the partial apply is building up a
// boxed aggregate to send off to the closed over function. Of course when you
// call the function, the proper conventions will be used.
//
// Emits, at the builder's current insertion point, the cheapest correct
// destroy for a single captured argument (destroy_addr for addresses,
// destroy_value under OSSA, otherwise strong_release/release_value with
// retain folding). Created/deleted instructions are reported via callbacks.
void polar::releasePartialApplyCapturedArg(PILBuilder &builder, PILLocation loc,
                                           PILValue arg,
                                           PILParameterInfo paramInfo,
                                           InstModCallbacks callbacks) {
   // Trivial values and @inout captures need no destruction at all.
   if (!shouldDestroyPartialApplyCapturedArg(arg, paramInfo,
                                             builder.getFunction()))
      return;

   // Otherwise, we need to destroy the argument. If we have an address, we
   // insert a destroy_addr and return. Any live range issues must have been
   // dealt with by our caller.
   if (arg->getType().isAddress()) {
      // Then emit the destroy_addr for this arg
      PILInstruction *newInst = builder.emitDestroyAddrAndFold(loc, arg);
      callbacks.createdNewInst(newInst);
      return;
   }

   // Otherwise, we have an object. We emit the most optimized form of release
   // possible for that value.

   // If we have qualified ownership, we should just emit a destroy value.
   if (arg->getFunction()->hasOwnership()) {
      callbacks.createdNewInst(builder.createDestroyValue(loc, arg));
      return;
   }

   if (arg->getType().hasReferenceSemantics()) {
      // The returned PointerUnion is either null (nothing to emit), a prior
      // strong_retain that cancelled against the release (delete it), or a
      // newly created strong_release (report it).
      auto u = builder.emitStrongRelease(loc, arg);
      if (u.isNull())
         return;

      if (auto *SRI = u.dyn_cast<StrongRetainInst *>()) {
         callbacks.deleteInst(SRI);
         return;
      }

      callbacks.createdNewInst(u.get<StrongReleaseInst *>());
      return;
   }

   // Non-reference, non-trivial value: same cancel-or-create handling as
   // above, but with retain_value/release_value.
   auto u = builder.emitReleaseValue(loc, arg);
   if (u.isNull())
      return;

   if (auto *rvi = u.dyn_cast<RetainValueInst *>()) {
      callbacks.deleteInst(rvi);
      return;
   }

   callbacks.createdNewInst(u.get<ReleaseValueInst *>());
}

/// For each captured argument of pai, decrement the ref count of the captured
/// argument as appropriate at each of the post dominated release locations
/// found by tracker.
///
/// Returns false (changing nothing) if any captured argument's type contains
/// an opened existential, since that case is not supported yet.
static bool releaseCapturedArgsOfDeadPartialApply(PartialApplyInst *pai,
                                                  ReleaseTracker &tracker,
                                                  InstModCallbacks callbacks) {
   PILBuilderWithScope builder(pai);
   PILLocation loc = pai->getLoc();
   CanPILFunctionType paiTy =
      pai->getCallee()->getType().getAs<PILFunctionType>();

   ArrayRef<PILParameterInfo> params = paiTy->getParameters();
   llvm::SmallVector<PILValue, 8> args;
   for (PILValue v : pai->getArguments()) {
      // If any of our arguments contain open existentials, bail. We do not
      // support this for now so that we can avoid having to re-order stack
      // locations (a larger change).
      if (v->getType().hasOpenedExistential())
         return false;
      args.emplace_back(v);
   }
   // A partial_apply supplies the trailing arguments of the callee, so drop
   // the leading (unapplied) parameters to align params[i] with args[i].
   unsigned delta = params.size() - args.size();
   assert(delta <= params.size()
          && "Error, more args to partial apply than "
             "params in its interface.");
   params = params.drop_front(delta);

   llvm::SmallVector<PILBasicBlock *, 2> exitingBlocks;
   pai->getFunction()->findExitingBlocks(exitingBlocks);

   // Go through our argument list and create new alloc_stacks for each
   // non-trivial address value. This ensures that the memory location that we
   // are cleaning up has the same live range as the partial_apply. Otherwise, we
   // may be inserting destroy_addr of alloc_stack that have already been passed
   // to a dealloc_stack.
   for (unsigned i : llvm::reverse(indices(args))) {
      PILValue arg = args[i];
      PILParameterInfo paramInfo = params[i];

      // If we are not going to destroy this partial_apply, continue.
      if (!shouldDestroyPartialApplyCapturedArg(arg, paramInfo,
                                                builder.getFunction()))
         continue;

      // If we have an object, we will not have live range issues, just continue.
      if (arg->getType().isObject())
         continue;

      // Now that we know that we have a non-argument address, perform a take-init
      // of arg into a lifetime extended alloc_stack
      args[i] = createLifetimeExtendedAllocStack(builder, loc, arg, exitingBlocks,
                                                 callbacks);
   }

   // Emit a destroy for each captured closure argument at each final release
   // point.
   for (auto *finalRelease : tracker.getFinalReleases()) {
      builder.setInsertionPoint(finalRelease);
      builder.setCurrentDebugScope(finalRelease->getDebugScope());
      for (unsigned i : indices(args)) {
         PILValue arg = args[i];
         PILParameterInfo param = params[i];

         releasePartialApplyCapturedArg(builder, loc, arg, param, callbacks);
      }
   }

   return true;
}

/// Returns true if \p inst is a mark_dependence whose transitive users are
/// all themselves dead mark_dependence instructions; collects the whole
/// chain (in pre-order) into \p deleteInsts.
static bool
deadMarkDependenceUser(PILInstruction *inst,
                       SmallVectorImpl<PILInstruction *> &deleteInsts) {
   // Anything other than mark_dependence keeps the value alive.
   if (!isa<MarkDependenceInst>(inst))
      return false;
   deleteInsts.push_back(inst);
   // Every transitive user must also be a dead mark_dependence (all_of
   // short-circuits on the first offender, like the original loop did).
   return llvm::all_of(cast<SingleValueInstruction>(inst)->getUses(),
                       [&deleteInsts](Operand *use) {
                          return deadMarkDependenceUser(use->getUser(),
                                                        deleteInsts);
                       });
}

/// TODO: Generalize this to general objects.
///
/// Attempts to delete a dead closure (partial_apply or
/// thin_to_thick_function) together with its ARC users. For a heap
/// partial_apply, also releases each captured argument at the closure's
/// final release points. Returns true if the closure was deleted.
bool polar::tryDeleteDeadClosure(SingleValueInstruction *closure,
                                 InstModCallbacks callbacks) {
   auto *pa = dyn_cast<PartialApplyInst>(closure);

   // We currently only handle locally identified values that do not escape. We
   // also assume that the partial apply does not capture any addresses.
   if (!pa && !isa<ThinToThickFunctionInst>(closure))
      return false;

   // A stack allocated partial apply does not have any release users. Delete it
   // if the only users are the dealloc_stack and mark_dependence instructions.
   if (pa && pa->isOnStack()) {
      SmallVector<PILInstruction *, 8> deleteInsts;
      for (auto *use : pa->getUses()) {
         if (isa<DeallocStackInst>(use->getUser())
             || isa<DebugValueInst>(use->getUser()))
            deleteInsts.push_back(use->getUser());
         else if (!deadMarkDependenceUser(use->getUser(), deleteInsts))
            return false;
      }
      // Delete in reverse collection order so leaf uses go before their defs.
      for (auto *inst : reverse(deleteInsts))
         callbacks.deleteInst(inst);
      callbacks.deleteInst(pa);

      // Note: the lifetime of the captured arguments is managed outside of the
      // trivial closure value i.e: there will already be releases for the
      // captured arguments. Releasing captured arguments is not necessary.
      return true;
   }

   // We only accept a user if it is an ARC object that can be removed if the
   // object is dead. This should be expanded in the future. This also ensures
   // that we are locally identified and non-escaping since we only allow for
   // specific ARC users.
   ReleaseTracker tracker(useDoesNotKeepClosureAlive, useHasTransitiveOwnership);

   // Find the ARC users and the final retain, release.
   if (!getFinalReleasesForValue(PILValue(closure), tracker))
      return false;

   // If we have a partial_apply, release each captured argument at each one of
   // the final release locations of the partial apply.
   if (auto *pai = dyn_cast<PartialApplyInst>(closure)) {
      // If we can not decrement the ref counts of the dead partial apply for any
      // reason, bail.
      if (!releaseCapturedArgsOfDeadPartialApply(pai, tracker, callbacks))
         return false;
   }

   // Then delete all user instructions in reverse so that leaf uses are deleted
   // first.
   for (auto *user : reverse(tracker.getTrackedUsers())) {
      assert(user->getResults().empty()
             || useHasTransitiveOwnership(user)
                && "We expect only ARC operations without "
                   "results or a cast from escape to noescape without users");
      callbacks.deleteInst(user);
   }

   // Finally delete the closure.
   callbacks.deleteInst(closure);

   return true;
}

/// Run simplifyInstruction over every user of \p inst, replacing and erasing
/// any user that folds to a simpler value. Returns true if anything changed.
bool polar::simplifyUsers(SingleValueInstruction *inst) {
   bool madeChange = false;

   // Walk the use list manually: advance the iterator before the user can be
   // erased so deletion never invalidates our position.
   auto ui = inst->use_begin();
   auto ue = inst->use_end();
   while (ui != ue) {
      PILInstruction *user = (ui++)->getUser();

      auto *svi = dyn_cast<SingleValueInstruction>(user);
      if (!svi)
         continue;

      if (PILValue simplified = simplifyInstruction(svi)) {
         replaceAllSimplifiedUsesAndErase(svi, simplified);
         madeChange = true;
      }
   }

   return madeChange;
}

/// True if a type can be expanded without a significant increase to code size.
bool polar::shouldExpand(PILModule &module, PILType ty) {
   // FIXME: Expansion
   auto expansion = TypeExpansionContext::minimal();

   // Address-only types can never be expanded.
   if (module.Types.getTypeLowering(ty, expansion).isAddressOnly())
      return false;

   // The -enable-expand-all flag forces expansion of every loadable type.
   if (EnableExpandAll)
      return true;

   // Otherwise only expand small aggregates.
   const unsigned maxFieldCountForExpansion = 6;
   unsigned fieldCount = module.Types.countNumberOfFields(ty, expansion);
   return fieldCount <= maxFieldCountForExpansion;
}

/// Some support functions for the global-opt and let-properties-opts

// Encapsulate the state used for recursive analysis of a static
// initializer. Discover all the instruction in a use-def graph and return them
// in topological order.
//
// TODO: We should have a DFS utility for this sort of thing so it isn't
// recursive.
class StaticInitializerAnalysis {
   // Output: instructions in def-use topological (post-DFS) order.
   SmallVectorImpl<PILInstruction *> &postOrderInstructions;
   // Values already seen; prevents revisiting shared sub-expressions.
   llvm::SmallDenseSet<PILValue, 8> visited;
   // DFS depth guard; analysis fails beyond 50 levels (see below).
   int recursionLevel = 0;

public:
   StaticInitializerAnalysis(
      SmallVectorImpl<PILInstruction *> &postOrderInstructions)
      : postOrderInstructions(postOrderInstructions) {}

   // Perform a recursive DFS on the use-def graph rooted at `V`. Insert
   // values in the `visited` set in preorder. Insert values in
   // `postOrderInstructions` in postorder so that the instructions are
   // topologically def-use ordered (in execution order).
   bool analyze(PILValue rootValue) {
      return recursivelyAnalyzeOperand(rootValue);
   }

protected:
   // Visit one value. Returns true if `v` (and everything it depends on)
   // is a supported static-initializer expression.
   bool recursivelyAnalyzeOperand(PILValue v) {
      // Already-visited values are accepted without re-emission; their
      // instructions are in postOrderInstructions from the first visit.
      if (!visited.insert(v).second)
         return true;

      // Depth guard against pathologically deep (or cyclic) graphs. NOTE:
      // the counter is only decremented on the success path, but that is
      // harmless since a failure aborts the whole analysis.
      if (++recursionLevel > 50)
         return false;

      // TODO: For multi-result instructions, we could simply insert all result
      // values in the visited set here.
      auto *inst = dyn_cast<SingleValueInstruction>(v);
      if (!inst)
         return false;

      if (!recursivelyAnalyzeInstruction(inst))
         return false;

      // Children emitted first -> postorder gives def-before-use order.
      postOrderInstructions.push_back(inst);
      --recursionLevel;
      return true;
   }

   // Accepts only: trivial structs/tuples of analyzable operands, an FPTrunc
   // builtin of a literal, and integer/float/string literals themselves.
   bool recursivelyAnalyzeInstruction(PILInstruction *inst) {
      if (auto *si = dyn_cast<StructInst>(inst)) {
         // If it is not a struct which is a simple type, bail.
         if (!si->getType().isTrivial(*si->getFunction()))
            return false;

         return llvm::all_of(si->getAllOperands(), [&](Operand &operand) -> bool {
            return recursivelyAnalyzeOperand(operand.get());
         });
      }
      if (auto *ti = dyn_cast<TupleInst>(inst)) {
         // If it is not a tuple which is a simple type, bail.
         if (!ti->getType().isTrivial(*ti->getFunction()))
            return false;

         return llvm::all_of(ti->getAllOperands(), [&](Operand &operand) -> bool {
            return recursivelyAnalyzeOperand(operand.get());
         });
      }
      if (auto *bi = dyn_cast<BuiltinInst>(inst)) {
         switch (bi->getBuiltinInfo().ID) {
            case BuiltinValueKind::FPTrunc:
               // Only a float-truncation of a literal is a constant we can fold.
               if (auto *li = dyn_cast<LiteralInst>(bi->getArguments()[0])) {
                  return recursivelyAnalyzeOperand(li);
               }
               return false;
            default:
               return false;
         }
      }
      // Leaves of the graph: plain literals.
      return isa<IntegerLiteralInst>(inst) || isa<FloatLiteralInst>(inst)
             || isa<StringLiteralInst>(inst);
   }
};

/// Check if the value of v is computed by means of a simple initialization.
/// Populate `forwardInstructions` with references to all the instructions
/// that participate in the use-def graph required to compute `V`. The
/// instructions will be in def-use topological order.
bool polar::analyzeStaticInitializer(
   PILValue v, SmallVectorImpl<PILInstruction *> &forwardInstructions) {
   StaticInitializerAnalysis analysis(forwardInstructions);
   return analysis.analyze(v);
}

/// FIXME: This must be kept in sync with replaceLoadSequence()
/// below. What a horrible design.
///
/// Returns true if every transitive use rooted at \p inst is a pattern that
/// replaceLoadSequence() knows how to rewrite: copy_addr and load terminate
/// a chain; struct_element_addr, tuple_element_addr and begin_access are
/// accepted only when all of their own users are replaceable; incidental
/// uses are ignored.
bool polar::canReplaceLoadSequence(PILInstruction *inst) {
   // Loads and copies terminate a replaceable sequence. Use isa<> — the
   // previous named dyn_cast<> bindings were never read and only produced
   // unused-variable warnings.
   if (isa<CopyAddrInst>(inst) || isa<LoadInst>(inst))
      return true;

   // Address projections and access scopes are replaceable iff all of their
   // users are. The three cases had identical recursion loops; unify them.
   if (isa<StructElementAddrInst>(inst) || isa<TupleElementAddrInst>(inst)
       || isa<BeginAccessInst>(inst)) {
      auto *svi = cast<SingleValueInstruction>(inst);
      return llvm::all_of(svi->getUses(), [](Operand *use) {
         return canReplaceLoadSequence(use->getUser());
      });
   }

   // Incidental uses of an address are meaningless with regard to the loaded
   // value.
   if (isIncidentalUse(inst) || isa<BeginUnpairedAccessInst>(inst))
      return true;

   return false;
}

/// Replace load sequence which may contain
/// a chain of struct_element_addr followed by a load.
/// The sequence is traversed inside out, i.e.
/// starting with the innermost struct_element_addr
/// Move into utils.
///
/// FIXME: this utility does not make sense as an API. How can the caller
/// guarantee that the only uses of `I` are struct_element_addr and
/// tuple_element_addr?
void polar::replaceLoadSequence(PILInstruction *inst, PILValue value) {
   // copy_addr of the global becomes a plain (unqualified) store of `value`
   // into the destination.
   if (auto *cai = dyn_cast<CopyAddrInst>(inst)) {
      PILBuilder builder(cai);
      builder.createStore(cai->getLoc(), value, cai->getDest(),
                          StoreOwnershipQualifier::Unqualified);
      return;
   }

   // A load is replaced directly by the known value.
   if (auto *li = dyn_cast<LoadInst>(inst)) {
      li->replaceAllUsesWith(value);
      return;
   }

   // Address projections become the corresponding value projections, and the
   // rewrite recurses into their users with the projected value.
   if (auto *seai = dyn_cast<StructElementAddrInst>(inst)) {
      PILBuilder builder(seai);
      auto *sei =
         builder.createStructExtract(seai->getLoc(), value, seai->getField());
      for (auto seaiUse : seai->getUses()) {
         replaceLoadSequence(seaiUse->getUser(), sei);
      }
      return;
   }

   if (auto *teai = dyn_cast<TupleElementAddrInst>(inst)) {
      PILBuilder builder(teai);
      auto *tei =
         builder.createTupleExtract(teai->getLoc(), value, teai->getFieldNo());
      for (auto teaiUse : teai->getUses()) {
         replaceLoadSequence(teaiUse->getUser(), tei);
      }
      return;
   }

   // begin_access is transparent: recurse into its users with the same value.
   if (auto *ba = dyn_cast<BeginAccessInst>(inst)) {
      for (auto use : ba->getUses()) {
         replaceLoadSequence(use->getUser(), value);
      }
      return;
   }

   // Incidental uses of an address are meaningless with regard to the loaded
   // value.
   if (isIncidentalUse(inst) || isa<BeginUnpairedAccessInst>(inst))
      return;

   llvm_unreachable("Unknown instruction sequence for reading from a global");
}

/// Are the callees that could be called through Decl statically
/// knowable based on the Decl and the compilation mode?
bool polar::calleesAreStaticallyKnowable(PILModule &module, PILDeclRef decl) {
   // Foreign entry points can be reached from outside the PIL we can see.
   if (decl.isForeign)
      return false;

   // Defer the real decision to the AbstractFunctionDecl-based overload.
   auto *fnDecl = decl.getAbstractFunctionDecl();
   assert(fnDecl && "Expected abstract function decl!");
   return calleesAreStaticallyKnowable(module, fnDecl);
}

/// Are the callees that could be called through Decl statically
/// knowable based on the Decl and the compilation mode?
///
/// True only for non-dynamic members of the module's associated context
/// whose effective access prevents overrides from outside the code we can
/// see (private/fileprivate always; public/internal only in whole-module
/// compilation; open never).
bool polar::calleesAreStaticallyKnowable(PILModule &module,
                                         AbstractFunctionDecl *afd) {
   const DeclContext *assocDC = module.getAssociatedContext();
   if (!assocDC)
      return false;

   // Only handle members defined within the PILModule's associated context.
   if (!afd->isChildContextOf(assocDC))
      return false;

   // Dynamic members can be replaced at runtime.
   if (afd->isDynamic()) {
      return false;
   }

   if (!afd->hasAccess())
      return false;

   // Only consider 'private' members, unless we are in whole-module compilation.
   switch (afd->getEffectiveAccess()) {
      case AccessLevel::Open:
         // Open members can always be overridden from other modules.
         return false;
      case AccessLevel::Public:
         if (isa<ConstructorDecl>(afd)) {
            // Constructors are special: a derived class in another module can
            // "override" a constructor if its class is "open", although the
            // constructor itself is not open.
            auto *nd = afd->getDeclContext()->getSelfNominalTypeDecl();
            if (nd->getEffectiveAccess() == AccessLevel::Open)
               return false;
         }
         LLVM_FALLTHROUGH;
      case AccessLevel::Interface:
         // Public/internal members are knowable only when we can see the
         // whole module.
         return module.isWholeModule();
      case AccessLevel::FilePrivate:
      case AccessLevel::Private:
         return true;
   }

   llvm_unreachable("Unhandled access level in switch.");
}

/// Collect all full and partial apply sites reachable from \p fri by walking
/// its transitive uses through function casts, mark_dependence, and ARC
/// operations. Sets `escapes` if any use is not one of those known-safe
/// patterns. Returns None when the reference escapes and no apply sites were
/// found (i.e. there is nothing useful to report).
Optional<FindLocalApplySitesResult>
polar::findLocalApplySites(FunctionRefBaseInst *fri) {
   SmallVector<Operand *, 32> worklist(fri->use_begin(), fri->use_end());

   Optional<FindLocalApplySitesResult> f;
   f.emplace();

   // Optimistically state that we have no escapes before our def-use dataflow.
   f->escapes = false;

   while (!worklist.empty()) {
      auto *op = worklist.pop_back_val();
      auto *user = op->getUser();

      // If we have a full apply site as our user.
      // (Only when the function ref is the callee — being passed as an
      // ordinary argument falls through to the escape cases below.)
      if (auto apply = FullApplySite::isa(user)) {
         if (apply.getCallee() == op->get()) {
            f->fullApplySites.push_back(apply);
            continue;
         }
      }

      // If we have a partial apply as a user, start tracking it, but also look at
      // its users.
      if (auto *pai = dyn_cast<PartialApplyInst>(user)) {
         if (pai->getCallee() == op->get()) {
            // Track the partial apply that we saw so we can potentially eliminate
            // dead closure arguments.
            f->partialApplySites.push_back(pai);
            // Look to see if we can find a full application of this partial apply
            // as well.
            llvm::copy(pai->getUses(), std::back_inserter(worklist));
            continue;
         }
      }

      // Otherwise, see if we have any function casts to look through...
      switch (user->getKind()) {
         case PILInstructionKind::ThinToThickFunctionInst:
         case PILInstructionKind::ConvertFunctionInst:
         case PILInstructionKind::ConvertEscapeToNoEscapeInst:
            llvm::copy(cast<SingleValueInstruction>(user)->getUses(),
                       std::back_inserter(worklist));
            continue;

            // A partial_apply [stack] marks its captured arguments with
            // mark_dependence.
         case PILInstructionKind::MarkDependenceInst:
            llvm::copy(cast<SingleValueInstruction>(user)->getUses(),
                       std::back_inserter(worklist));
            continue;

            // Look through any reference count instructions since these are not
            // escapes:
         case PILInstructionKind::CopyValueInst:
            llvm::copy(cast<CopyValueInst>(user)->getUses(),
                       std::back_inserter(worklist));
            continue;
         case PILInstructionKind::StrongRetainInst:
         case PILInstructionKind::StrongReleaseInst:
         case PILInstructionKind::RetainValueInst:
         case PILInstructionKind::ReleaseValueInst:
         case PILInstructionKind::DestroyValueInst:
            // A partial_apply [stack] is deallocated with a dealloc_stack.
         case PILInstructionKind::DeallocStackInst:
            continue;
         default:
            break;
      }

      // But everything else is considered an escape.
      f->escapes = true;
   }

   // If we did escape and didn't find any apply sites, then we have no
   // information for our users that is interesting.
   if (f->escapes && f->partialApplySites.empty() && f->fullApplySites.empty())
      return None;
   return f;
}

/// Insert destroys of captured arguments of partial_apply [stack].
///
/// For each captured argument accepted by \p shouldInsertDestroy, looks up
/// the callee-side parameter convention and emits the appropriate release at
/// the builder's current insertion point via releasePartialApplyCapturedArg.
void polar::insertDestroyOfCapturedArguments(
   PartialApplyInst *pai, PILBuilder &builder,
   llvm::function_ref<bool(PILValue)> shouldInsertDestroy) {
   assert(pai->isOnStack());

   ApplySite site(pai);
   PILFunctionConventions calleeConv(site.getSubstCalleeType(),
                                     pai->getModule());
   auto loc = RegularLocation::getAutoGeneratedLocation();
   for (auto &arg : pai->getArgumentOperands()) {
      if (!shouldInsertDestroy(arg.get()))
         continue;
      // Map the operand to its callee-side parameter info so the destroy
      // matches the capture's convention.
      unsigned calleeArgumentIndex = site.getCalleeArgIndex(arg);
      assert(calleeArgumentIndex >= calleeConv.getPILArgIndexOfFirstParam());
      auto paramInfo = calleeConv.getParamInfoForPILArg(calleeArgumentIndex);
      // NOTE(review): called without the callbacks argument — presumably
      // InstModCallbacks is defaulted in the header declaration; confirm.
      releasePartialApplyCapturedArg(builder, loc, arg.get(), paramInfo);
   }
}

/// Walk the override chain of \p FD upward and return the root method that
/// every override ultimately derives from (FD itself if it overrides
/// nothing).
AbstractFunctionDecl *polar::getBaseMethod(AbstractFunctionDecl *FD) {
   while (auto *overridden = FD->getOverriddenDecl())
      FD = overridden;
   return FD;
}
